{
 "cells": [
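  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tutorial: Converting the VGG_16 model with tensorflow2fluid\n",
    "\n",
    "VGG_16 is a classic model in computer vision. This tutorial uses the [VGG_16](https://github.com/tensorflow/models/blob/master/research/slim/nets/vgg.py) implementation from tensorflow/models as an example to show how to convert a trained TensorFlow model into a PaddlePaddle model.\n",
    "\n",
    "### Download the pretrained model"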
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Download percentage 100.00%"
     ]
    }
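   ],
   "source": [
    "import sys\n",
    "\n",
    "try:\n",
    "    from urllib.request import urlretrieve  # Python 3\n",
    "except ImportError:\n",
    "    from urllib import urlretrieve  # Python 2\n",
    "\n",
    "def schedule(block_num, block_size, total_size):\n",
    "    # Progress hook: report the downloaded percentage, capped at 100\n",
    "    per = min(100.0, 100.0 * block_num * block_size / total_size)\n",
    "    sys.stderr.write(\"\\rDownload percentage %.2f%%\" % per)\n",
    "    sys.stderr.flush()\n",
    "\n",
    "url = \"http://download.tensorflow.org/models/vgg_16_2016_08_28.tar.gz\"\n",
    "fetch = urlretrieve(url, \"./vgg_16.tar.gz\", schedule)"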
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Extract the downloaded archive"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tarfile\n",
    "\n",
    "# Extract every member of the archive into the current directory\n",
    "with tarfile.open(\"./vgg_16.tar.gz\", \"r:gz\") as f:\n",
    "    f.extractall(\"./\")"
   ]
  },
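  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Save the model in checkpoint format\n",
    "\n",
    "tensorflow2fluid currently supports models in checkpoint format, or in pb format with the network structure and parameters serialized together. The `vgg_16.ckpt` file downloaded above stores only the model parameters, so we need to reload the parameters and save the network structure together with them as a checkpoint model.\n",
    "\n",
    "**Note: the code below, which runs the TensorFlow model and converts it to a PaddlePaddle model, depends on TensorFlow.**"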
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from vgg_16.ckpt\n"
     ]
    }
   ],
   "source": [
    "import tensorflow.contrib.slim as slim\n",
    "from tensorflow.contrib.slim.nets import vgg\n",
    "import tensorflow as tf\n",
    "import numpy\n",
    "\n",
    "with tf.Session() as sess:\n",
    "    # Build the VGG_16 graph and load the pretrained parameters into it\n",
    "    inputs = tf.placeholder(dtype=tf.float32, shape=[None, 224, 224, 3], name=\"inputs\")\n",
    "    logits, endpoint = vgg.vgg_16(inputs, num_classes=1000, is_training=False)\n",
    "    load_model = slim.assign_from_checkpoint_fn(\"vgg_16.ckpt\", slim.get_model_variables(\"vgg_16\"))\n",
    "    load_model(sess)\n",
    "\n",
    "    # Run inference on seeded random data and save the result for the later diff\n",
    "    numpy.random.seed(13)\n",
    "    data = numpy.random.rand(5, 224, 224, 3)\n",
    "    input_tensor = sess.graph.get_tensor_by_name(\"inputs:0\")\n",
    "    output_tensor = sess.graph.get_tensor_by_name(\"vgg_16/fc8/squeezed:0\")\n",
    "    result = sess.run([output_tensor], {input_tensor:data})\n",
    "    numpy.save(\"tensorflow.npy\", numpy.array(result))\n",
    "\n",
    "    # Save the network structure and parameters together as a checkpoint\n",
    "    saver = tf.train.Saver()\n",
    "    saver.save(sess, \"./checkpoint/model\")"
   ]
  },
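  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Convert the model to a PaddlePaddle model\n",
    "\n",
    "**Note**: For some OPs, the conversion has to write parameters to a file, or run inference on the TensorFlow model to obtain tensor values. Both cases take some time for IO or computation. For the latter case, setting the `use_cuda` parameter to `True` during conversion is recommended to speed up inference.\n",
    "\n",
    "You can set the parameters in code and run the conversion from Python using the **model conversion Python script** below. Alternatively, the conversion can usually be run from the command line as follows:\n",
    "```shell\n",
    "# The model can also be converted from the command line\n",
    "python tf2fluid/convert.py --meta_file checkpoint/model.meta --ckpt_dir checkpoint \\\n",
    "                           --in_nodes inputs --input_shape None,224,224,3 \\\n",
    "                           --output_nodes vgg_16/fc8/squeezed --use_cuda True \\\n",
    "                           --input_format NHWC --save_dir paddle_model\n",
    "```"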
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Model conversion Python script"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:root:Loading tensorflow model...\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from checkpoint/model\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from checkpoint/model\n",
      "INFO:root:Tensorflow model loaded!\n",
      "INFO:root:TotalNum:86,TraslatedNum:1,CurrentNode:inputs\n",
      "INFO:root:TotalNum:86,TraslatedNum:2,CurrentNode:vgg_16/conv1/conv1_1/weights\n",
      "INFO:root:TotalNum:86,TraslatedNum:3,CurrentNode:vgg_16/conv1/conv1_1/biases\n",
      "INFO:root:TotalNum:86,TraslatedNum:4,CurrentNode:vgg_16/conv1/conv1_2/weights\n",
      "INFO:root:TotalNum:86,TraslatedNum:5,CurrentNode:vgg_16/conv1/conv1_2/biases\n",
      "INFO:root:TotalNum:86,TraslatedNum:6,CurrentNode:vgg_16/conv2/conv2_1/weights\n",
      "INFO:root:TotalNum:86,TraslatedNum:7,CurrentNode:vgg_16/conv2/conv2_1/biases\n",
      "INFO:root:TotalNum:86,TraslatedNum:8,CurrentNode:vgg_16/conv2/conv2_2/weights\n",
      "INFO:root:TotalNum:86,TraslatedNum:9,CurrentNode:vgg_16/conv2/conv2_2/biases\n",
      "INFO:root:TotalNum:86,TraslatedNum:10,CurrentNode:vgg_16/conv3/conv3_1/weights\n",
      "INFO:root:TotalNum:86,TraslatedNum:11,CurrentNode:vgg_16/conv3/conv3_1/biases\n",
      "INFO:root:TotalNum:86,TraslatedNum:12,CurrentNode:vgg_16/conv3/conv3_2/weights\n",
      "INFO:root:TotalNum:86,TraslatedNum:13,CurrentNode:vgg_16/conv3/conv3_2/biases\n",
      "INFO:root:TotalNum:86,TraslatedNum:14,CurrentNode:vgg_16/conv3/conv3_3/weights\n",
      "INFO:root:TotalNum:86,TraslatedNum:15,CurrentNode:vgg_16/conv3/conv3_3/biases\n",
      "INFO:root:TotalNum:86,TraslatedNum:16,CurrentNode:vgg_16/conv4/conv4_1/weights\n",
      "INFO:root:TotalNum:86,TraslatedNum:17,CurrentNode:vgg_16/conv4/conv4_1/biases\n",
      "INFO:root:TotalNum:86,TraslatedNum:18,CurrentNode:vgg_16/conv4/conv4_2/weights\n",
      "INFO:root:TotalNum:86,TraslatedNum:19,CurrentNode:vgg_16/conv4/conv4_2/biases\n",
      "INFO:root:TotalNum:86,TraslatedNum:20,CurrentNode:vgg_16/conv4/conv4_3/weights\n",
      "INFO:root:TotalNum:86,TraslatedNum:21,CurrentNode:vgg_16/conv4/conv4_3/biases\n",
      "INFO:root:TotalNum:86,TraslatedNum:22,CurrentNode:vgg_16/conv5/conv5_1/weights\n",
      "INFO:root:TotalNum:86,TraslatedNum:23,CurrentNode:vgg_16/conv5/conv5_1/biases\n",
      "INFO:root:TotalNum:86,TraslatedNum:24,CurrentNode:vgg_16/conv5/conv5_2/weights\n",
      "INFO:root:TotalNum:86,TraslatedNum:25,CurrentNode:vgg_16/conv5/conv5_2/biases\n",
      "INFO:root:TotalNum:86,TraslatedNum:26,CurrentNode:vgg_16/conv5/conv5_3/weights\n",
      "INFO:root:TotalNum:86,TraslatedNum:27,CurrentNode:vgg_16/conv5/conv5_3/biases\n",
      "INFO:root:TotalNum:86,TraslatedNum:28,CurrentNode:vgg_16/fc6/weights\n",
      "INFO:root:TotalNum:86,TraslatedNum:29,CurrentNode:vgg_16/fc6/biases\n",
      "INFO:root:TotalNum:86,TraslatedNum:30,CurrentNode:vgg_16/fc7/weights\n",
      "INFO:root:TotalNum:86,TraslatedNum:31,CurrentNode:vgg_16/fc7/biases\n",
      "INFO:root:TotalNum:86,TraslatedNum:32,CurrentNode:vgg_16/fc8/weights\n",
      "INFO:root:TotalNum:86,TraslatedNum:33,CurrentNode:vgg_16/fc8/biases\n",
      "INFO:root:TotalNum:86,TraslatedNum:34,CurrentNode:vgg_16/conv1/conv1_1/Conv2D\n",
      "INFO:root:TotalNum:86,TraslatedNum:35,CurrentNode:vgg_16/conv1/conv1_1/BiasAdd\n",
      "INFO:root:TotalNum:86,TraslatedNum:36,CurrentNode:vgg_16/conv1/conv1_1/Relu\n",
      "INFO:root:TotalNum:86,TraslatedNum:37,CurrentNode:vgg_16/conv1/conv1_2/Conv2D\n",
      "INFO:root:TotalNum:86,TraslatedNum:38,CurrentNode:vgg_16/conv1/conv1_2/BiasAdd\n",
      "INFO:root:TotalNum:86,TraslatedNum:39,CurrentNode:vgg_16/conv1/conv1_2/Relu\n",
      "INFO:root:TotalNum:86,TraslatedNum:40,CurrentNode:vgg_16/pool1/MaxPool\n",
      "INFO:root:TotalNum:86,TraslatedNum:41,CurrentNode:vgg_16/conv2/conv2_1/Conv2D\n",
      "INFO:root:TotalNum:86,TraslatedNum:42,CurrentNode:vgg_16/conv2/conv2_1/BiasAdd\n",
      "INFO:root:TotalNum:86,TraslatedNum:43,CurrentNode:vgg_16/conv2/conv2_1/Relu\n",
      "INFO:root:TotalNum:86,TraslatedNum:44,CurrentNode:vgg_16/conv2/conv2_2/Conv2D\n",
      "INFO:root:TotalNum:86,TraslatedNum:45,CurrentNode:vgg_16/conv2/conv2_2/BiasAdd\n",
      "INFO:root:TotalNum:86,TraslatedNum:46,CurrentNode:vgg_16/conv2/conv2_2/Relu\n",
      "INFO:root:TotalNum:86,TraslatedNum:47,CurrentNode:vgg_16/pool2/MaxPool\n",
      "INFO:root:TotalNum:86,TraslatedNum:48,CurrentNode:vgg_16/conv3/conv3_1/Conv2D\n",
      "INFO:root:TotalNum:86,TraslatedNum:49,CurrentNode:vgg_16/conv3/conv3_1/BiasAdd\n",
      "INFO:root:TotalNum:86,TraslatedNum:50,CurrentNode:vgg_16/conv3/conv3_1/Relu\n",
      "INFO:root:TotalNum:86,TraslatedNum:51,CurrentNode:vgg_16/conv3/conv3_2/Conv2D\n",
      "INFO:root:TotalNum:86,TraslatedNum:52,CurrentNode:vgg_16/conv3/conv3_2/BiasAdd\n",
      "INFO:root:TotalNum:86,TraslatedNum:53,CurrentNode:vgg_16/conv3/conv3_2/Relu\n",
      "INFO:root:TotalNum:86,TraslatedNum:54,CurrentNode:vgg_16/conv3/conv3_3/Conv2D\n",
      "INFO:root:TotalNum:86,TraslatedNum:55,CurrentNode:vgg_16/conv3/conv3_3/BiasAdd\n",
      "INFO:root:TotalNum:86,TraslatedNum:56,CurrentNode:vgg_16/conv3/conv3_3/Relu\n",
      "INFO:root:TotalNum:86,TraslatedNum:57,CurrentNode:vgg_16/pool3/MaxPool\n",
      "INFO:root:TotalNum:86,TraslatedNum:58,CurrentNode:vgg_16/conv4/conv4_1/Conv2D\n",
      "INFO:root:TotalNum:86,TraslatedNum:59,CurrentNode:vgg_16/conv4/conv4_1/BiasAdd\n",
      "INFO:root:TotalNum:86,TraslatedNum:60,CurrentNode:vgg_16/conv4/conv4_1/Relu\n",
      "INFO:root:TotalNum:86,TraslatedNum:61,CurrentNode:vgg_16/conv4/conv4_2/Conv2D\n",
      "INFO:root:TotalNum:86,TraslatedNum:62,CurrentNode:vgg_16/conv4/conv4_2/BiasAdd\n",
      "INFO:root:TotalNum:86,TraslatedNum:63,CurrentNode:vgg_16/conv4/conv4_2/Relu\n",
      "INFO:root:TotalNum:86,TraslatedNum:64,CurrentNode:vgg_16/conv4/conv4_3/Conv2D\n",
      "INFO:root:TotalNum:86,TraslatedNum:65,CurrentNode:vgg_16/conv4/conv4_3/BiasAdd\n",
      "INFO:root:TotalNum:86,TraslatedNum:66,CurrentNode:vgg_16/conv4/conv4_3/Relu\n",
      "INFO:root:TotalNum:86,TraslatedNum:67,CurrentNode:vgg_16/pool4/MaxPool\n",
      "INFO:root:TotalNum:86,TraslatedNum:68,CurrentNode:vgg_16/conv5/conv5_1/Conv2D\n",
      "INFO:root:TotalNum:86,TraslatedNum:69,CurrentNode:vgg_16/conv5/conv5_1/BiasAdd\n",
      "INFO:root:TotalNum:86,TraslatedNum:70,CurrentNode:vgg_16/conv5/conv5_1/Relu\n",
      "INFO:root:TotalNum:86,TraslatedNum:71,CurrentNode:vgg_16/conv5/conv5_2/Conv2D\n",
      "INFO:root:TotalNum:86,TraslatedNum:72,CurrentNode:vgg_16/conv5/conv5_2/BiasAdd\n",
      "INFO:root:TotalNum:86,TraslatedNum:73,CurrentNode:vgg_16/conv5/conv5_2/Relu\n",
      "INFO:root:TotalNum:86,TraslatedNum:74,CurrentNode:vgg_16/conv5/conv5_3/Conv2D\n",
      "INFO:root:TotalNum:86,TraslatedNum:75,CurrentNode:vgg_16/conv5/conv5_3/BiasAdd\n",
      "INFO:root:TotalNum:86,TraslatedNum:76,CurrentNode:vgg_16/conv5/conv5_3/Relu\n",
      "INFO:root:TotalNum:86,TraslatedNum:77,CurrentNode:vgg_16/pool5/MaxPool\n",
      "INFO:root:TotalNum:86,TraslatedNum:78,CurrentNode:vgg_16/fc6/Conv2D\n",
      "INFO:root:TotalNum:86,TraslatedNum:79,CurrentNode:vgg_16/fc6/BiasAdd\n",
      "INFO:root:TotalNum:86,TraslatedNum:80,CurrentNode:vgg_16/fc6/Relu\n",
      "INFO:root:TotalNum:86,TraslatedNum:81,CurrentNode:vgg_16/fc7/Conv2D\n",
      "INFO:root:TotalNum:86,TraslatedNum:82,CurrentNode:vgg_16/fc7/BiasAdd\n",
      "INFO:root:TotalNum:86,TraslatedNum:83,CurrentNode:vgg_16/fc7/Relu\n",
      "INFO:root:TotalNum:86,TraslatedNum:84,CurrentNode:vgg_16/fc8/Conv2D\n",
      "INFO:root:TotalNum:86,TraslatedNum:85,CurrentNode:vgg_16/fc8/BiasAdd\n",
      "INFO:root:TotalNum:86,TraslatedNum:86,CurrentNode:vgg_16/fc8/squeezed\n",
      "INFO:root:Model translated!\n"
     ]
    }
   ],
   "source": [
    "import tf2fluid.convert as convert\n",
    "\n",
    "# Set the same options as the command-line flags as attributes on the parser object, then run the conversion\n",
    "parser = convert._get_parser()\n",
    "parser.meta_file = \"checkpoint/model.meta\"\n",
    "parser.ckpt_dir = \"checkpoint\"\n",
    "parser.in_nodes = [\"inputs\"]\n",
    "parser.input_shape = [\"None,224,224,3\"]\n",
    "parser.output_nodes = [\"vgg_16/fc8/squeezed\"]\n",
    "parser.use_cuda = \"True\"\n",
    "parser.input_format = \"NHWC\"\n",
    "parser.save_dir = \"paddle_model\"\n",
    "\n",
    "convert.run(parser)"
   ]
  },
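  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load the converted PaddlePaddle model and run prediction\n",
    "Note that the converted PaddlePaddle CV model takes **input in NCHW format**.\n",
    "\n",
    "**Note: the code below runs the converted PaddlePaddle model and compares its results with the TensorFlow results, so it depends on PaddlePaddle.**"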
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-03-15T05:51:40.544737Z",
     "start_time": "2019-03-15T05:51:27.857863Z"
    }
   },
   "outputs": [],
   "source": [
    "import numpy\n",
    "import tf2fluid.model_loader as ml\n",
    "\n",
    "model = ml.ModelLoader(\"paddle_model\", use_cuda=False)\n",
    "\n",
    "numpy.random.seed(13)\n",
    "data = numpy.random.rand(5, 224, 224, 3).astype(\"float32\")\n",
    "# NHWC -> NCHW\n",
    "data = numpy.transpose(data, (0, 3, 1, 2))\n",
    "\n",
    "results = model.inference(feed_dict={model.inputs[0]:data})\n",
    "\n",
    "numpy.save(\"paddle.npy\", numpy.array(results))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Compare the prediction diff between the models before and after conversion"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-03-15T05:52:02.126718Z",
     "start_time": "2019-03-15T05:52:02.115849Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "6.67572e-06\n"
     ]
    }
   ],
   "source": [
    "import numpy\n",
    "paddle_result = numpy.load(\"paddle.npy\")\n",
    "tensorflow_result = numpy.load(\"tensorflow.npy\")\n",
    "diff = numpy.fabs(paddle_result - tensorflow_result)\n",
    "print(numpy.max(diff))"
   ]
  },
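  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The maximum absolute diff computed above can also be turned into an explicit pass/fail check. The following is only a minimal sketch and is not part of tf2fluid: the helper name `check_diff` and the tolerance `1e-4` are assumptions chosen for illustration, and small stand-in arrays replace the saved `.npy` files so that the snippet runs on its own.\n",
    "\n",
    "```python\n",
    "import numpy\n",
    "\n",
    "def check_diff(reference, converted, atol=1e-4):\n",
    "    # Hypothetical helper: return the max absolute diff and whether it is within atol\n",
    "    reference = numpy.asarray(reference, dtype=numpy.float32)\n",
    "    converted = numpy.asarray(converted, dtype=numpy.float32)\n",
    "    max_diff = float(numpy.max(numpy.fabs(reference - converted)))\n",
    "    return max_diff, max_diff <= atol\n",
    "\n",
    "# Stand-in data; in this notebook these would be the arrays loaded from tensorflow.npy and paddle.npy\n",
    "tensorflow_result = numpy.array([[0.1, 0.2, 0.3]], dtype=numpy.float32)\n",
    "paddle_result = tensorflow_result + numpy.float32(5e-6)\n",
    "\n",
    "max_diff, ok = check_diff(tensorflow_result, paddle_result)\n",
    "print(\"max diff: %g, within tolerance: %s\" % (max_diff, ok))\n",
    "```\n",
    "\n",
    "A suitable tolerance depends on the model and on how its outputs are used, so the value here should be treated as a placeholder."
   ]
  },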
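  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Points to note\n",
    "1. Pay attention to the input format of the converted model: PaddlePaddle expects input in NCHW format.  \n",
    "2. This example does not involve intermediate-layer outputs such as convolution outputs. If you need them, be aware that the shapes of convolution outputs and convolution kernels in PaddlePaddle differ from those in TensorFlow.  \n",
    "3. After conversion, check the diff between the models before and after conversion. In this example, the measured maximum diff (6.67572e-06) meets the conversion requirements.  "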
   ]
  }
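  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The shape differences mentioned in point 2 can be illustrated with plain NumPy transposes. This is a sketch under stated assumptions rather than what tf2fluid does internally: it assumes the usual TensorFlow kernel layout `[height, width, in_channels, out_channels]` and the PaddlePaddle layout `[out_channels, in_channels, height, width]`, and the array sizes are illustrative values for a 3x3 convolution from 64 to 128 channels.\n",
    "\n",
    "```python\n",
    "import numpy\n",
    "\n",
    "# TensorFlow Conv2D kernel layout: [height, width, in_channels, out_channels]\n",
    "tf_kernel = numpy.zeros((3, 3, 64, 128), dtype=numpy.float32)\n",
    "# PaddlePaddle conv2d kernel layout: [out_channels, in_channels, height, width]\n",
    "paddle_kernel = numpy.transpose(tf_kernel, (3, 2, 0, 1))\n",
    "print(paddle_kernel.shape)\n",
    "\n",
    "# Convolution outputs: NHWC in TensorFlow, NCHW in PaddlePaddle\n",
    "tf_output = numpy.zeros((5, 112, 112, 128), dtype=numpy.float32)\n",
    "paddle_output = numpy.transpose(tf_output, (0, 3, 1, 2))\n",
    "print(paddle_output.shape)\n",
    "```\n",
    "\n",
    "When comparing an intermediate PaddlePaddle output with its TensorFlow counterpart, one of the two arrays has to be transposed in this way before taking the diff."
   ]
  }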
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:GPU-Paddle]",
   "language": "python",
   "name": "conda-env-GPU-Paddle-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.15"
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
